MaxPoolGrad

描述 MaxPool 的反向传播(梯度)计算。该算子将上游梯度(dy)只回传到前向最大池化过程中被选为最大值的位置;其它位置的梯度为 0。

数学定义:

\[\begin{split}\text{output}_{b,\ h_i,\ w_i,\ c} = \begin{cases} \text{dy}_{b,\ h_o,\ w_o,\ c}, & \text{if } (h_i,\ w_i) = \displaystyle \arg\max_{(h,w)\in\mathcal{W}(h_o,w_o)} \text{input}_{b,\ h,\ w,\ c}, \\ 0, & \text{otherwise}. \end{cases}\end{split}\]

其中,\(\mathcal{W}(h_o, w_o)\) 表示输出位置 \((h_o, w_o)\) 对应的池化窗口区域。窗口像素位置 \((h, w)\) 可表示为:

\[h = h_o \cdot \text{stride}_h - \text{pad}_u + \Delta h\]
\[w = w_o \cdot \text{stride}_w - \text{pad}_l + \Delta w\]
\[\Delta h \in [0,\ \text{win}_h - 1], \qquad \Delta w \in [0,\ \text{win}_w - 1]\]

并且仅当采样点落在输入有效范围内时会被考虑:

\[0 \le h < \text{in}_h, \qquad 0 \le w < \text{in}_w.\]
实现细节说明:
  • 前向池化使用窗口 \(\text{win}_h \times \text{win}_w\),步长为 \(\text{stride}_h\), \(\text{stride}_w\),并且在边界处使用 pad(pad_u, pad_l)。

  • 反向传播时,输出梯度 tensor(即需要写入的输入梯度)在每个 batch 开始前先被初始化为 0(代码中有一次整体清零)。

  • 对于每个输出像素 \((h_o,w_o)\) 以及每个通道 c:

  • 在对应的输入窗口中找到前向最大值的位置 \((h^*,w^*)\)

  • 将上游梯度 \(\text{dy}_{b,h_o,w_o,c}\) 累加到该位置:\(\text{output}_{b,h^*,w^*,c} \mathrel{+}= \text{dy}_{b,h_o,w_o,c}\)

  • 其他位置梯度保持 0。

输入:
  • input - 输入张量指针,采用 NHWC 格式,形状为 \([batch,\ in\_h,\ in\_w,\ channel]\)

  • dy - 上游梯度张量指针,采用 NHWC 格式,形状为 \([batch,\ output\_h,\ output\_w,\ channel]\)

  • in_w - 输入张量的宽度 (W)

  • in_h - 输入张量的高度 (H)

  • win_w - 池化窗口的宽度,即窗口在 W 方向的大小

  • win_h - 池化窗口的高度,即窗口在 H 方向的大小

  • output_w - 输出特征图的宽度

  • output_h - 输出特征图的高度

  • batch - 批次大小,即输入中的 batch 数

  • channel - 通道数 C ,每个池化位置都分别对 C 个通道独立执行最大池化与裁剪

  • stride_w - 池化窗口在 W 方向的步长

  • stride_h - 池化窗口在 H 方向的步长

  • pad_l - 输入特征图左侧的填充大小

  • pad_u - 输入特征图上侧的填充大小

  • minf - 输出结果的下界值。池化结果会执行 \(\max(v,\ \text{minf})\)

  • maxf - 输出结果的上界值。池化结果会执行 \(\min(v,\ \text{maxf})\)

  • core_mask - 核心掩码,指定使用的计算核心

输出:
  • output - 输出张量指针,采用 NHWC 格式,形状为 \([batch,\ in\_h,\ in\_w,\ channel]\)

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持fp32, fp64

  • MT7004 支持fp16, fp32

  • 调用时将除 core_mask 外的参数打包通过 long long params 数组传入,顺序为: input, dy, output, in_w, in_h, win_w, win_h, output_w, output_h, batch, channel, stride_w, stride_h, pad_l, pad_u, minf, maxf

共享存储版本:

void hp_maxpool_grad_s(long long *params, int core_mask)
void fp_maxpool_grad_s(long long *params, int core_mask)
void dp_maxpool_grad_s(long long *params, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3
 4int main(int argc, char* argv[]) {
 5    double* input_ptr = (double*)0xA0000000;
 6    double* dy_ptr = (double*)0xB0000000;
 7    double* output_ptr = (double*)0xC0000000;
 8    double* check_ptr = (double*)0xD0000000;
 9    int in_w = gin_w;
10    int in_h = gin_h;
11    int win_w = 6;
12    int win_h = 6;
13    int batch = gbatch;
14    int channel = 2;
15    int stride_w = 4;
16    int stride_h = 4;
17    int pad_l = 1;
18    int pad_u = 1;
19    double minf = 0.0f;
20    double maxf = 50.0f;
21
22    // 根据标准公式计算输出尺寸
23    int dividor = in_w + pad_l*2 - win_w;
24    int output_w = (dividor + stride_w - 1) / stride_w + 1;
25    int dividor2 = in_h + pad_u*2 - win_h;
26    int output_h = (dividor2 + stride_h - 1) / stride_h + 1;
27
28    long long params[17];
29    params[0] = (long long)input_ptr;
30    params[1] = (long long)dy_ptr;
31    params[2] = (long long)output_ptr;
32    params[3] = (long long)in_w;
33    params[4] = (long long)in_h;
34    params[5] = (long long)win_w;
35    params[6] = (long long)win_h;
36    params[7] = (long long)output_w;
37    params[8] = (long long)output_h;
38    params[9] = (long long)batch;
39    params[10] = (long long)channel;
40    params[11] = (long long)stride_w;
41    params[12] = (long long)stride_h;
42    params[13] = (long long)pad_l;
43    params[14] = (long long)pad_u;
44    params[15] = (long long)&minf; //注意这里传指针,不能直接强制转换成long long
45    params[16] = (long long)&maxf;
46    int core_mask = 0x0f;
47    fp_maxpool_grad_s(params, core_mask);
48    return 0;
49}

私有存储版本:

void hp_maxpool_grad_p(long long *params)
void fp_maxpool_grad_p(long long *params)
void dp_maxpool_grad_p(long long *params)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3
 4int main(int argc, char* argv[]) {
 5    double* input_ptr = (double*)0xA0000000;
 6    double* dy_ptr = (double*)0xB0000000;
 7    double* output_ptr = (double*)0xC0000000;
 8    double* check_ptr = (double*)0xD0000000;
 9    int in_w = gin_w;
10    int in_h = gin_h;
11    int win_w = 6;
12    int win_h = 6;
13    int batch = gbatch;
14    int channel = 2;
15    int stride_w = 4;
16    int stride_h = 4;
17    int pad_l = 1;
18    int pad_u = 1;
19    double minf = 0.0f;
20    double maxf = 50.0f;
21
22    // 根据标准公式计算输出尺寸
23    int dividor = in_w + pad_l*2 - win_w;
24    int output_w = (dividor + stride_w - 1) / stride_w + 1;
25    int dividor2 = in_h + pad_u*2 - win_h;
26    int output_h = (dividor2 + stride_h - 1) / stride_h + 1;
27
28    long long params[17];
29    params[0] = (long long)input_ptr;
30    params[1] = (long long)dy_ptr;
31    params[2] = (long long)output_ptr;
32    params[3] = (long long)in_w;
33    params[4] = (long long)in_h;
34    params[5] = (long long)win_w;
35    params[6] = (long long)win_h;
36    params[7] = (long long)output_w;
37    params[8] = (long long)output_h;
38    params[9] = (long long)batch;
39    params[10] = (long long)channel;
40    params[11] = (long long)stride_w;
41    params[12] = (long long)stride_h;
42    params[13] = (long long)pad_l;
43    params[14] = (long long)pad_u;
44    params[15] = (long long)&minf; //注意这里传指针,不能直接强制转换成long long
45    params[16] = (long long)&maxf;
46    fp_maxpool_grad_p(params);
47    return 0;
48}